# TODO: disable tqdm

configfile: "config/config.yaml"

# Target rule: requests the final SCENIC+ MuData file
# (config["output_data"]["scplus_mdata"]), which is the output of rule
# `scplus_mudata`. It is the first rule in this file, so Snakemake uses it as
# the default target when no target is given on the command line.
rule all:
    input:
        scplus_mudata_out_fname=config["output_data"]["scplus_mdata"]

# Prepare data
# Exactly one of the two rules below is defined, chosen when the Snakefile is
# parsed from config["params_data_preparation"]["is_multiome"].
# Both run `scenicplus prepare_data prepare_GEX_ACC` on the cisTopic object and
# the gene-expression AnnData file and write
# config["output_data"]["combined_GEX_ACC_mudata"], which most downstream rules
# take as input.
# The non-multiome variant additionally passes --is_not_multiome,
# --key_to_group_by and --nr_cells_per_metacells.
# Params are given as functions of `wildcards`, so the config look-up happens
# when Snakemake evaluates the job rather than when the rule is declared.
if config["params_data_preparation"]["is_multiome"]:
    rule prepare_GEX_ACC_multiome:
        input:
            cisTopic_obj_fname=config["input_data"]["cisTopic_obj_fname"],
            GEX_anndata_fname=config["input_data"]["GEX_anndata_fname"]
        output:
            config["output_data"]["combined_GEX_ACC_mudata"]
        params:
            bc_transform_func=lambda wildcards: config["params_data_preparation"]["bc_transform_func"]
        shell:
            """
            scenicplus prepare_data prepare_GEX_ACC \
                --cisTopic_obj_fname {input.cisTopic_obj_fname} \
                --GEX_anndata_fname {input.GEX_anndata_fname} \
                --out_file {output} \
                --bc_transform_func {params.bc_transform_func}
            """
else:
    rule prepare_GEX_ACC_non_multiome:
        input:
            cisTopic_obj_fname=config["input_data"]["cisTopic_obj_fname"],
            GEX_anndata_fname=config["input_data"]["GEX_anndata_fname"]
        output:
            config["output_data"]["combined_GEX_ACC_mudata"]
        params:
            bc_transform_func=lambda wildcards: config["params_data_preparation"]["bc_transform_func"],
            key_to_group_by=lambda wildcards: config["params_data_preparation"]["key_to_group_by"],
            nr_cells_per_metacells=lambda wildcards: config["params_data_preparation"]["nr_cells_per_metacells"]
        shell:
            """
            scenicplus prepare_data prepare_GEX_ACC \
                --cisTopic_obj_fname {input.cisTopic_obj_fname} \
                --GEX_anndata_fname {input.GEX_anndata_fname} \
                --out_file {output} \
                --bc_transform_func {params.bc_transform_func} \
                --is_not_multiome \
                --key_to_group_by {params.key_to_group_by} \
                --nr_cells_per_metacells {params.nr_cells_per_metacells}
            """

# Runs `scenicplus grn_inference motif_enrichment_cistarget` on
# `region_set_folder`, using the cisTarget database and the motif annotation
# file from config["input_data"].
# Writes the enrichment result file and, because --write_html is always passed,
# an HTML report as a second output.
# `threads` is read from config["params_general"]["n_cpu"] and forwarded to the
# tool as --n_cpu.
rule motif_enrichment_cistarget:
    input:
        region_set_folder=config["input_data"]["region_set_folder"],
        ctx_db_fname=config["input_data"]["ctx_db_fname"],
        path_to_motif_annotations=config["input_data"]["path_to_motif_annotations"]
    output:
        ctx_result_fname=config["output_data"]["ctx_result_fname"],
        output_fname_ctx_html=config["output_data"]["output_fname_ctx_html"]
    params:
        temp_dir=lambda wildcards: config["params_general"]["temp_dir"],
        species=lambda wildcards: config["params_motif_enrichment"]["species"],
        fraction_overlap_w_ctx_database=lambda wildcards: config["params_motif_enrichment"]["fraction_overlap_w_ctx_database"],
        auc_threshold=lambda wildcards: config["params_motif_enrichment"]["ctx_auc_threshold"],
        nes_threshold=lambda wildcards: config["params_motif_enrichment"]["ctx_nes_threshold"],
        rank_threshold=lambda wildcards: config["params_motif_enrichment"]["ctx_rank_threshold"],
        annotation_version=lambda wildcards: config["params_motif_enrichment"]["annotation_version"],
        motif_similarity_fdr=lambda wildcards: config["params_motif_enrichment"]["motif_similarity_fdr"],
        orthologous_identity_threshold=lambda wildcards: config["params_motif_enrichment"]["orthologous_identity_threshold"],
        annotations_to_use=lambda wildcards: config["params_motif_enrichment"]["annotations_to_use"]
    threads: config["params_general"]["n_cpu"]
    shell:
        """
            scenicplus grn_inference motif_enrichment_cistarget \
                --region_set_folder {input.region_set_folder} \
                --cistarget_db_fname {input.ctx_db_fname} \
                --output_fname_cistarget_result {output.ctx_result_fname} \
                --temp_dir {params.temp_dir} \
                --species {params.species} \
                --fr_overlap_w_ctx_db {params.fraction_overlap_w_ctx_database} \
                --auc_threshold {params.auc_threshold} \
                --nes_threshold {params.nes_threshold} \
                --rank_threshold {params.rank_threshold} \
                --path_to_motif_annotations {input.path_to_motif_annotations} \
                --annotation_version {params.annotation_version} \
                --motif_similarity_fdr {params.motif_similarity_fdr} \
                --orthologous_identity_threshold {params.orthologous_identity_threshold} \
                --annotations_to_use {params.annotations_to_use} \
                --write_html \
                --output_fname_cistarget_html {output.output_fname_ctx_html} \
                --n_cpu {threads}
        """

# Differential motif enrichment (`scenicplus grn_inference motif_enrichment_dem`).
# Exactly one of the two definitions of `motif_enrichment_dem` below is
# registered, chosen when the Snakefile is parsed from
# config["params_motif_enrichment"]["dem_balance_number_of_promoters"].
#
# Both variants read `region_set_folder`, the DEM database and the motif
# annotation file, and write the DEM result file plus an HTML report
# (--write_html is always passed).
#
# Differences between the variants:
#   * balanced:     also takes the genome annotation file (output of rule
#                   `download_genome_annotations`) as input and passes
#                   --genome_annotation, --balance_number_of_promoters and
#                   --promoter_space.
#   * not balanced: passes none of those three options, so it neither depends on
#                   the genome annotation file nor reads the
#                   `dem_promoter_space` config key.
#
# The CPU count has a single source: `threads` (from
# config["params_general"]["n_cpu"]) is forwarded as --n_cpu. No separate
# `n_cpu` param is declared, because the shell command never referenced one.
if config["params_motif_enrichment"]["dem_balance_number_of_promoters"]:
    rule motif_enrichment_dem:
        input:
            region_set_folder=config["input_data"]["region_set_folder"],
            dem_db_fname=config["input_data"]["dem_db_fname"],
            genome_annotation_fname=config["output_data"]["genome_annotation"],
            path_to_motif_annotations=config["input_data"]["path_to_motif_annotations"]
        output:
            dem_result_fname=config["output_data"]["dem_result_fname"],
            output_fname_dem_html=config["output_data"]["output_fname_dem_html"]
        params:
            temp_dir=lambda wildcards: config["params_general"]["temp_dir"],
            species=lambda wildcards: config["params_motif_enrichment"]["species"],
            fraction_overlap_w_dem_database=lambda wildcards: config["params_motif_enrichment"]["fraction_overlap_w_dem_database"],
            max_bg_regions=lambda wildcards: config["params_motif_enrichment"]["dem_max_bg_regions"],
            promoter_space=lambda wildcards: config["params_motif_enrichment"]["dem_promoter_space"],
            adj_pval_thr=lambda wildcards: config["params_motif_enrichment"]["dem_adj_pval_thr"],
            log2fc_thr=lambda wildcards: config["params_motif_enrichment"]["dem_log2fc_thr"],
            mean_fg_thr=lambda wildcards: config["params_motif_enrichment"]["dem_mean_fg_thr"],
            motif_hit_thr=lambda wildcards: config["params_motif_enrichment"]["dem_motif_hit_thr"],
            annotation_version=lambda wildcards: config["params_motif_enrichment"]["annotation_version"],
            motif_similarity_fdr=lambda wildcards: config["params_motif_enrichment"]["motif_similarity_fdr"],
            orthologous_identity_threshold=lambda wildcards: config["params_motif_enrichment"]["orthologous_identity_threshold"],
            annotations_to_use=lambda wildcards: config["params_motif_enrichment"]["annotations_to_use"],
            seed=lambda wildcards: config["params_general"]["seed"]
        threads: config["params_general"]["n_cpu"]
        shell:
            """
                scenicplus grn_inference motif_enrichment_dem \
                    --region_set_folder {input.region_set_folder} \
                    --dem_db_fname {input.dem_db_fname} \
                    --output_fname_dem_result {output.dem_result_fname} \
                    --temp_dir {params.temp_dir} \
                    --species {params.species} \
                    --fraction_overlap_w_dem_database {params.fraction_overlap_w_dem_database} \
                    --max_bg_regions {params.max_bg_regions} \
                    --genome_annotation {input.genome_annotation_fname} \
                    --balance_number_of_promoters \
                    --promoter_space {params.promoter_space} \
                    --adjpval_thr {params.adj_pval_thr} \
                    --log2fc_thr {params.log2fc_thr} \
                    --mean_fg_thr {params.mean_fg_thr} \
                    --motif_hit_thr {params.motif_hit_thr} \
                    --path_to_motif_annotations {input.path_to_motif_annotations} \
                    --annotation_version {params.annotation_version} \
                    --motif_similarity_fdr {params.motif_similarity_fdr} \
                    --orthologous_identity_threshold {params.orthologous_identity_threshold} \
                    --annotations_to_use {params.annotations_to_use} \
                    --write_html \
                    --output_fname_dem_html {output.output_fname_dem_html} \
                    --seed {params.seed} \
                    --n_cpu {threads}
            """
else:
    rule motif_enrichment_dem:
        input:
            region_set_folder=config["input_data"]["region_set_folder"],
            dem_db_fname=config["input_data"]["dem_db_fname"],
            path_to_motif_annotations=config["input_data"]["path_to_motif_annotations"]
        output:
            dem_result_fname=config["output_data"]["dem_result_fname"],
            output_fname_dem_html=config["output_data"]["output_fname_dem_html"]
        params:
            temp_dir=lambda wildcards: config["params_general"]["temp_dir"],
            species=lambda wildcards: config["params_motif_enrichment"]["species"],
            fraction_overlap_w_dem_database=lambda wildcards: config["params_motif_enrichment"]["fraction_overlap_w_dem_database"],
            max_bg_regions=lambda wildcards: config["params_motif_enrichment"]["dem_max_bg_regions"],
            adj_pval_thr=lambda wildcards: config["params_motif_enrichment"]["dem_adj_pval_thr"],
            log2fc_thr=lambda wildcards: config["params_motif_enrichment"]["dem_log2fc_thr"],
            mean_fg_thr=lambda wildcards: config["params_motif_enrichment"]["dem_mean_fg_thr"],
            motif_hit_thr=lambda wildcards: config["params_motif_enrichment"]["dem_motif_hit_thr"],
            annotation_version=lambda wildcards: config["params_motif_enrichment"]["annotation_version"],
            motif_similarity_fdr=lambda wildcards: config["params_motif_enrichment"]["motif_similarity_fdr"],
            orthologous_identity_threshold=lambda wildcards: config["params_motif_enrichment"]["orthologous_identity_threshold"],
            annotations_to_use=lambda wildcards: config["params_motif_enrichment"]["annotations_to_use"],
            seed=lambda wildcards: config["params_general"]["seed"]
        threads: config["params_general"]["n_cpu"]
        shell:
            """
                scenicplus grn_inference motif_enrichment_dem \
                    --region_set_folder {input.region_set_folder} \
                    --dem_db_fname {input.dem_db_fname} \
                    --output_fname_dem_result {output.dem_result_fname} \
                    --temp_dir {params.temp_dir} \
                    --species {params.species} \
                    --fraction_overlap_w_dem_database {params.fraction_overlap_w_dem_database} \
                    --max_bg_regions {params.max_bg_regions} \
                    --adjpval_thr {params.adj_pval_thr} \
                    --log2fc_thr {params.log2fc_thr} \
                    --mean_fg_thr {params.mean_fg_thr} \
                    --motif_hit_thr {params.motif_hit_thr} \
                    --path_to_motif_annotations {input.path_to_motif_annotations} \
                    --annotation_version {params.annotation_version} \
                    --motif_similarity_fdr {params.motif_similarity_fdr} \
                    --orthologous_identity_threshold {params.orthologous_identity_threshold} \
                    --annotations_to_use {params.annotations_to_use} \
                    --write_html \
                    --output_fname_dem_html {output.output_fname_dem_html} \
                    --seed {params.seed} \
                    --n_cpu {threads}
            """

# Runs `scenicplus prepare_data prepare_menr` on the motif enrichment results
# (DEM result first, then cisTarget result, as passed to
# --paths_to_motif_enrichment_results) together with the combined GEX/ACC
# MuData.
# Writes three files: the TF name list (input of rule `tf_to_gene`) and the
# direct and extended cistromes (inputs of rules `eGRN_direct` and
# `eGRN_extended` respectively).
rule prepare_menr:
    input:
        dem_result_fname=config["output_data"]["dem_result_fname"],
        ctx_result_fname=config["output_data"]["ctx_result_fname"],
        multiome_mudata_fname=config["output_data"]["combined_GEX_ACC_mudata"]
    output:
        tf_names=config["output_data"]["tf_names"],
        cistromes_direct=config["output_data"]["cistromes_direct"],
        cistromes_extended=config["output_data"]["cistromes_extended"]
    params:
        direct_annotation=lambda wildcards: config["params_data_preparation"]["direct_annotation"],
        extended_annotation=lambda wildcards: config["params_data_preparation"]["extended_annotation"]
    shell:
        """
        scenicplus prepare_data prepare_menr \
            --paths_to_motif_enrichment_results {input.dem_result_fname} {input.ctx_result_fname} \
            --multiome_mudata_fname {input.multiome_mudata_fname} \
            --out_file_tf_names {output.tf_names} \
            --out_file_direct_annotation {output.cistromes_direct} \
            --out_file_extended_annotation {output.cistromes_extended} \
            --direct_annotation {params.direct_annotation} \
            --extended_annotation {params.extended_annotation}
        """

# Runs `scenicplus prepare_data download_genome_annotations` for the configured
# species and BioMart host. The rule has no input files.
# Writes the genome annotation file and the chromsizes file, both consumed by
# rule `get_search_space`.
# NOTE(review): presumably needs network access to reach `biomart_host` —
# confirm before running on nodes without outbound connectivity.
rule download_genome_annotations:
    output:
        genome_annotation=config["output_data"]["genome_annotation"],
        chromsizes=config["output_data"]["chromsizes"]
    params:
        species=lambda wildcards: config["params_data_preparation"]["species"],
        biomart_host=lambda wildcards: config["params_data_preparation"]["biomart_host"]
    shell:
        """
        scenicplus prepare_data download_genome_annotations \
            --species {params.species} \
            --biomart_host {params.biomart_host} \
            --genome_annotation_out_fname {output.genome_annotation} \
            --chromsizes_out_fname {output.chromsizes}
        """

# Computes the search space file from the combined GEX/ACC MuData, the genome
# annotation and the chromsizes file (the latter two are outputs of rule
# `download_genome_annotations`). The result is the `search_space` input of
# rule `region_to_gene`.
# NOTE(review): the subcommand is spelled `search_spance` in the shell command.
# Presumably this matches the name the scenicplus CLI actually registers —
# check `scenicplus prepare_data --help` before "correcting" the spelling.
rule get_search_space:
    input:
        multiome_mudata_fname=config["output_data"]["combined_GEX_ACC_mudata"],
        genome_annotation=config["output_data"]["genome_annotation"],
        chromsizes=config["output_data"]["chromsizes"]
    output:
        search_space=config["output_data"]["search_space"]
    params:
        upstream=lambda wildcards: config["params_data_preparation"]["search_space_upstream"],
        downstream=lambda wildcards: config["params_data_preparation"]["search_space_downstream"],
        extend_tss=lambda wildcards: config["params_data_preparation"]["search_space_extend_tss"]
    shell:
        """
        scenicplus prepare_data search_spance \
            --multiome_mudata_fname {input.multiome_mudata_fname} \
            --gene_annotation_fname {input.genome_annotation} \
            --chromsizes_fname {input.chromsizes} \
            --out_fname {output.search_space} \
            --upstream {params.upstream} \
            --downstream {params.downstream} \
            --extend_tss {params.extend_tss}
        """

# Runs `scenicplus grn_inference TF_to_gene` on the combined GEX/ACC MuData and
# the TF name list written by rule `prepare_menr`.
# Writes the TF-to-gene adjacencies file used by both eGRN rules.
# `threads` (config["params_general"]["n_cpu"]) is forwarded as --n_cpu, and the
# configured seed is passed as --seed.
rule tf_to_gene:
    input:
        multiome_mudata_fname=config["output_data"]["combined_GEX_ACC_mudata"],
        tf_names=config["output_data"]["tf_names"],
    output:
        tf_to_gene_adjacencies=config["output_data"]["tf_to_gene_adjacencies"]
    params:
        temp_dir=lambda wildcards: config["params_general"]["temp_dir"],
        method=lambda wildcards: config["params_inference"]["tf_to_gene_importance_method"],
        seed=lambda wildcards: config["params_general"]["seed"]
    threads: config["params_general"]["n_cpu"]
    shell:
        """
        scenicplus grn_inference TF_to_gene \
            --multiome_mudata_fname {input.multiome_mudata_fname} \
            --tf_names {input.tf_names} \
            --temp_dir {params.temp_dir} \
            --out_tf_to_gene_adjacencies {output.tf_to_gene_adjacencies} \
            --method {params.method} \
            --n_cpu {threads} \
            --seed {params.seed}
        """

# Runs `scenicplus grn_inference region_to_gene` on the combined GEX/ACC MuData
# and the search space file written by rule `get_search_space`.
# Writes the region-to-gene adjacencies file used by both eGRN rules.
# The importance and correlation scoring methods come from
# config["params_inference"]; `threads` is forwarded as --n_cpu.
rule region_to_gene:
    input:
        multiome_mudata_fname=config["output_data"]["combined_GEX_ACC_mudata"],
        search_space=config["output_data"]["search_space"],
    output:
        region_to_gene_adjacencies=config["output_data"]["region_to_gene_adjacencies"]
    params:
        temp_dir=lambda wildcards: config["params_general"]["temp_dir"],
        method_importance=lambda wildcards: config["params_inference"]["region_to_gene_importance_method"],
        method_correlation=lambda wildcards: config["params_inference"]["region_to_gene_correlation_method"],
    threads: config["params_general"]["n_cpu"]
    shell:
        """
        scenicplus grn_inference region_to_gene \
            --multiome_mudata_fname {input.multiome_mudata_fname} \
            --search_space_fname {input.search_space} \
            --temp_dir {params.temp_dir} \
            --out_region_to_gene_adjacencies {output.region_to_gene_adjacencies} \
            --importance_scoring_method {params.method_importance} \
            --correlation_scoring_method {params.method_correlation} \
            --n_cpu {threads}
        """

# Runs `scenicplus grn_inference eGRN` on the direct cistromes.
# Inputs: TF-to-gene and region-to-gene adjacencies, the direct cistromes from
# rule `prepare_menr`, and the cisTarget database (config key `ctx_db_fname`),
# which is passed as --ranking_db_fname.
# Output: the direct eRegulons file.
# Same command and params as rule `eGRN_extended`, except that this rule does
# not pass --is_extended and uses the direct cistromes / direct output file.
# NOTE(review): the --TF_to_gene_adj_fname and --cistromes_fname lines have no
# space before the trailing backslash; the arguments stay separated only by the
# leading indentation of the following line, so keep that indentation.
rule eGRN_direct:
    input:
        tf_to_gene_adjacencies=config["output_data"]["tf_to_gene_adjacencies"],
        region_to_gene_adjacencies=config["output_data"]["region_to_gene_adjacencies"],
        cistromes_direct=config["output_data"]["cistromes_direct"],
        ranking_db_fname=config["input_data"]["ctx_db_fname"]
    output:
        eRegulons_direct=config["output_data"]["eRegulons_direct"]
    params:
        temp_dir=lambda wildcards: config["params_general"]["temp_dir"],
        order_regions_to_genes_by=lambda wildcards: config["params_inference"]["order_regions_to_genes_by"],
        order_TFs_to_genes_by=lambda wildcards: config["params_inference"]["order_TFs_to_genes_by"],
        gsea_n_perm=lambda wildcards: config["params_inference"]["gsea_n_perm"],
        quantiles=lambda wildcards: config["params_inference"]["quantile_thresholds_region_to_gene"],
        top_n_regionTogenes_per_gene=lambda wildcards: config["params_inference"]["top_n_regionTogenes_per_gene"],
        top_n_regionTogenes_per_region=lambda wildcards: config["params_inference"]["top_n_regionTogenes_per_region"],
        min_regions_per_gene=lambda wildcards: config["params_inference"]["min_regions_per_gene"],
        rho_threshold=lambda wildcards: config["params_inference"]["rho_threshold"],
        min_target_genes=lambda wildcards: config["params_inference"]["min_target_genes"]
    threads: config["params_general"]["n_cpu"]
    shell:
        """
        scenicplus grn_inference eGRN \
            --TF_to_gene_adj_fname {input.tf_to_gene_adjacencies}\
            --region_to_gene_adj_fname {input.region_to_gene_adjacencies} \
            --cistromes_fname {input.cistromes_direct}\
            --ranking_db_fname {input.ranking_db_fname} \
            --eRegulon_out_fname {output.eRegulons_direct} \
            --temp_dir {params.temp_dir} \
            --order_regions_to_genes_by {params.order_regions_to_genes_by} \
            --order_TFs_to_genes_by {params.order_TFs_to_genes_by} \
            --gsea_n_perm {params.gsea_n_perm} \
            --quantiles {params.quantiles} \
            --top_n_regionTogenes_per_gene {params.top_n_regionTogenes_per_gene} \
            --top_n_regionTogenes_per_region {params.top_n_regionTogenes_per_region} \
            --min_regions_per_gene {params.min_regions_per_gene} \
            --rho_threshold {params.rho_threshold} \
            --min_target_genes {params.min_target_genes} \
            --n_cpu {threads}
        """

# Runs `scenicplus grn_inference eGRN --is_extended` on the extended cistromes.
# Inputs: TF-to-gene and region-to-gene adjacencies, the extended cistromes from
# rule `prepare_menr`, and the cisTarget database (config key `ctx_db_fname`),
# which is passed as --ranking_db_fname.
# Output: the extended eRegulons file.
# Same command and params as rule `eGRN_direct`, plus the --is_extended flag and
# the extended cistromes / extended output file.
# NOTE(review): the --TF_to_gene_adj_fname line has no space before the trailing
# backslash; the arguments stay separated only by the leading indentation of the
# following line, so keep that indentation.
rule eGRN_extended:
    input:
        tf_to_gene_adjacencies=config["output_data"]["tf_to_gene_adjacencies"],
        region_to_gene_adjacencies=config["output_data"]["region_to_gene_adjacencies"],
        cistromes_extended=config["output_data"]["cistromes_extended"],
        ranking_db_fname=config["input_data"]["ctx_db_fname"]
    output:
        eRegulons_extended=config["output_data"]["eRegulons_extended"]
    params:
        temp_dir=lambda wildcards: config["params_general"]["temp_dir"],
        order_regions_to_genes_by=lambda wildcards: config["params_inference"]["order_regions_to_genes_by"],
        order_TFs_to_genes_by=lambda wildcards: config["params_inference"]["order_TFs_to_genes_by"],
        gsea_n_perm=lambda wildcards: config["params_inference"]["gsea_n_perm"],
        quantiles=lambda wildcards: config["params_inference"]["quantile_thresholds_region_to_gene"],
        top_n_regionTogenes_per_gene=lambda wildcards: config["params_inference"]["top_n_regionTogenes_per_gene"],
        top_n_regionTogenes_per_region=lambda wildcards: config["params_inference"]["top_n_regionTogenes_per_region"],
        min_regions_per_gene=lambda wildcards: config["params_inference"]["min_regions_per_gene"],
        rho_threshold=lambda wildcards: config["params_inference"]["rho_threshold"],
        min_target_genes=lambda wildcards: config["params_inference"]["min_target_genes"]
    threads: config["params_general"]["n_cpu"]
    shell:
        """
        scenicplus grn_inference eGRN \
            --is_extended \
            --TF_to_gene_adj_fname {input.tf_to_gene_adjacencies}\
            --region_to_gene_adj_fname {input.region_to_gene_adjacencies} \
            --cistromes_fname {input.cistromes_extended} \
            --ranking_db_fname {input.ranking_db_fname} \
            --eRegulon_out_fname {output.eRegulons_extended} \
            --temp_dir {params.temp_dir} \
            --order_regions_to_genes_by {params.order_regions_to_genes_by} \
            --order_TFs_to_genes_by {params.order_TFs_to_genes_by} \
            --gsea_n_perm {params.gsea_n_perm} \
            --quantiles {params.quantiles} \
            --top_n_regionTogenes_per_gene {params.top_n_regionTogenes_per_gene} \
            --top_n_regionTogenes_per_region {params.top_n_regionTogenes_per_region} \
            --min_regions_per_gene {params.min_regions_per_gene} \
            --rho_threshold {params.rho_threshold} \
            --min_target_genes {params.min_target_genes} \
            --n_cpu {threads}
        """
# Runs `scenicplus grn_inference AUCell` on the direct eRegulons (output of rule
# `eGRN_direct`) and the combined GEX/ACC MuData.
# Writes config["output_data"]["AUCell_direct"], an input of rule
# `scplus_mudata`. `threads` is forwarded as --n_cpu.
rule AUCell_direct:
    input:
        eRegulons_direct=config["output_data"]["eRegulons_direct"],
        multiome_mudata_fname=config["output_data"]["combined_GEX_ACC_mudata"],
    output:
        aucell_direct_out_fname=config["output_data"]["AUCell_direct"]
    threads: config["params_general"]["n_cpu"]
    shell:
        """
        scenicplus grn_inference AUCell \
            --eRegulon_fname {input.eRegulons_direct} \
            --multiome_mudata_fname {input.multiome_mudata_fname} \
            --aucell_out_fname {output.aucell_direct_out_fname} \
            --n_cpu {threads}
        """

# Runs `scenicplus grn_inference AUCell` on the extended eRegulons (output of
# rule `eGRN_extended`) and the combined GEX/ACC MuData.
# Writes config["output_data"]["AUCell_extended"], an input of rule
# `scplus_mudata`. `threads` is forwarded as --n_cpu.
rule AUCell_extended:
    input:
        eRegulons_extended=config["output_data"]["eRegulons_extended"],
        multiome_mudata_fname=config["output_data"]["combined_GEX_ACC_mudata"],
    output:
        aucell_extended_out_fname=config["output_data"]["AUCell_extended"]
    threads: config["params_general"]["n_cpu"]
    shell:
        """
        scenicplus grn_inference AUCell \
            --eRegulon_fname {input.eRegulons_extended} \
            --multiome_mudata_fname {input.multiome_mudata_fname} \
            --aucell_out_fname {output.aucell_extended_out_fname} \
            --n_cpu {threads}
        """

# Final step: runs `scenicplus grn_inference create_scplus_mudata` on the
# combined GEX/ACC MuData, the direct and extended AUCell results, and the
# direct and extended eRegulon files.
# Its single output, config["output_data"]["scplus_mdata"], is the file
# requested by rule `all`.
rule scplus_mudata:
    input:
        multiome_mudata_fname=config["output_data"]["combined_GEX_ACC_mudata"],
        aucell_direct_fname=config["output_data"]["AUCell_direct"],
        aucell_extended_fname=config["output_data"]["AUCell_extended"],
        eRegulons_direct_fname=config["output_data"]["eRegulons_direct"],
        eRegulons_extended_fname=config["output_data"]["eRegulons_extended"]
    output:
        scplus_mudata_out_fname=config["output_data"]["scplus_mdata"]
    shell:
        """
        scenicplus grn_inference create_scplus_mudata \
            --multiome_mudata_fname {input.multiome_mudata_fname} \
            --e_regulon_auc_direct_mudata_fname {input.aucell_direct_fname} \
            --e_regulon_auc_extended_mudata_fname {input.aucell_extended_fname} \
            --e_regulon_metadata_direct_fname {input.eRegulons_direct_fname} \
            --e_regulon_metadata_extended_fname {input.eRegulons_extended_fname} \
            --out_file {output.scplus_mudata_out_fname}
        """